{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fEwwyiWMlncS"
   },
   "source": [
    "##### Copyright 2020 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "lT4CrDemeqzN"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "w5BqikBFR8Jb"
   },
   "source": [
    "# Writing your own callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gYNefYV3toRo"
   },
   "source": [
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/custom_callback\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/keras-team/keras-io/blob/master/tf/custom_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/keras-team/keras-io/blob/master/guides/writing_your_own_callbacks.py\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://storage.googleapis.com/tensorflow_docs/keras-io/tf/custom_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "M80XSuTIvJld"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "A callback is a powerful tool to customize the behavior of a Keras model during\n",
    "training, evaluation, or inference. Examples include `tf.keras.callbacks.TensorBoard`\n",
    "to visualize training progress and results with TensorBoard, or\n",
    "`tf.keras.callbacks.ModelCheckpoint` to periodically save your model during training.\n",
    "\n",
    "In this guide, you will learn what a Keras callback is, what it can do, and how you can\n",
    "build your own. We provide a few demos of simple callback applications to get you\n",
    "started."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "m6Tx5DcPg1o6"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "UxOi7EEE6Q9Q"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DjgZ3jSko2kU"
   },
   "source": [
    "## Keras callbacks overview\n",
    "\n",
    "All callbacks subclass the `keras.callbacks.Callback` class, and\n",
    "override a set of methods called at various stages of training, testing, and\n",
    "predicting. Callbacks are useful to get a view on internal states and statistics of\n",
    "the model during training.\n",
    "\n",
    "You can pass a list of callbacks (as the keyword argument `callbacks`) to the following\n",
    "model methods:\n",
    "\n",
    "- `keras.Model.fit()`\n",
    "- `keras.Model.evaluate()`\n",
    "- `keras.Model.predict()`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "VpW1Q0YjwH4s"
   },
   "source": [
    "## An overview of callback methods\n",
    "\n",
    "### Global methods\n",
    "\n",
    "#### `on_(train|test|predict)_begin(self, logs=None)`\n",
    "\n",
    "Called at the beginning of `fit`/`evaluate`/`predict`.\n",
    "\n",
    "#### `on_(train|test|predict)_end(self, logs=None)`\n",
    "\n",
    "Called at the end of `fit`/`evaluate`/`predict`.\n",
    "\n",
    "### Batch-level methods for training/testing/predicting\n",
    "\n",
    "#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)`\n",
    "\n",
    "Called right before processing a batch during training/testing/predicting.\n",
    "\n",
    "#### `on_(train|test|predict)_batch_end(self, batch, logs=None)`\n",
    "\n",
    "Called at the end of training/testing/predicting a batch. Within this method, `logs` is\n",
    "a dict containing the metrics results.\n",
    "\n",
    "### Epoch-level methods (training only)\n",
    "\n",
    "#### `on_epoch_begin(self, epoch, logs=None)`\n",
    "\n",
    "Called at the beginning of an epoch during training.\n",
    "\n",
    "#### `on_epoch_end(self, epoch, logs=None)`\n",
    "\n",
    "Called at the end of an epoch during training."
   ]
  },
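  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The order in which the training hooks listed above fire can be illustrated with a small,\n",
    "framework-free sketch. `RecordingHooks` and `toy_fit` are hypothetical names invented for\n",
    "this illustration, and the loop is a simplified assumption about how `fit` sequences the\n",
    "hooks, not the actual Keras implementation (validation, for instance, is left out)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "class RecordingHooks:\n",
    "    \"\"\"Hypothetical stand-in for a callback: records the name of each hook called.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.calls = []\n",
    "\n",
    "    def __getattr__(self, name):\n",
    "        # Only reached for attributes that do not exist, i.e. the `on_*` hooks.\n",
    "        if not name.startswith(\"on_\"):\n",
    "            raise AttributeError(name)\n",
    "\n",
    "        def hook(*args, logs=None):\n",
    "            self.calls.append(name)\n",
    "\n",
    "        return hook\n",
    "\n",
    "\n",
    "def toy_fit(hooks, epochs, batches_per_epoch):\n",
    "    \"\"\"Simplified illustration of the hook order during training (an assumption).\"\"\"\n",
    "    hooks.on_train_begin()\n",
    "    for epoch in range(epochs):\n",
    "        hooks.on_epoch_begin(epoch)\n",
    "        for batch in range(batches_per_epoch):\n",
    "            hooks.on_train_batch_begin(batch)\n",
    "            # Placeholder logs; a real run would carry the loss and metrics here.\n",
    "            hooks.on_train_batch_end(batch, logs={\"loss\": 0.0})\n",
    "        hooks.on_epoch_end(epoch, logs={\"loss\": 0.0})\n",
    "    hooks.on_train_end()\n",
    "\n",
    "\n",
    "hooks = RecordingHooks()\n",
    "toy_fit(hooks, epochs=1, batches_per_epoch=2)\n",
    "print(hooks.calls)"
   ]
  },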
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wt9nVKBw2Zbd"
   },
   "source": [
    "## A basic example\n",
    "\n",
    "Let's take a look at a concrete example. To get started, let's import tensorflow and\n",
    "define a simple Sequential Keras model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "pE2CcBRVisfl"
   },
   "outputs": [],
   "source": [
    "# Define the Keras model to add callbacks to\n",
    "def get_model():\n",
    "    model = keras.Sequential()\n",
    "    model.add(keras.layers.Dense(1, input_dim=784))\n",
    "    model.compile(\n",
    "        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),\n",
    "        loss=\"mean_squared_error\",\n",
    "        metrics=[\"mean_absolute_error\"],\n",
    "    )\n",
    "    return model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "csY3ImU6n4AE"
   },
   "source": [
    "Then, load the MNIST data for training and testing from Keras datasets API:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "m39zXsmb6brk"
   },
   "outputs": [],
   "source": [
    "# Load example MNIST data and pre-process it\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = x_train.reshape(-1, 784).astype(\"float32\") / 255.0\n",
    "x_test = x_test.reshape(-1, 784).astype(\"float32\") / 255.0\n",
    "\n",
    "# Limit the data to 1000 samples\n",
    "x_train = x_train[:1000]\n",
    "y_train = y_train[:1000]\n",
    "x_test = x_test[:1000]\n",
    "y_test = y_test[:1000]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ah1Hsy79P0tI"
   },
   "source": [
    "Now, define a simple custom callback that logs:\n",
    "\n",
    "- When `fit`/`evaluate`/`predict` starts & ends\n",
    "- When each epoch starts & ends\n",
    "- When each training batch starts & ends\n",
    "- When each evaluation (test) batch starts & ends\n",
    "- When each inference (prediction) batch starts & ends"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "PVzRdrXuQYo7"
   },
   "outputs": [],
   "source": [
    "class CustomCallback(keras.callbacks.Callback):\n",
    "    def on_train_begin(self, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"Starting training; got log keys: {}\".format(keys))\n",
    "\n",
    "    def on_train_end(self, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"Stop training; got log keys: {}\".format(keys))\n",
    "\n",
    "    def on_epoch_begin(self, epoch, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"Start epoch {} of training; got log keys: {}\".format(epoch, keys))\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"End epoch {} of training; got log keys: {}\".format(epoch, keys))\n",
    "\n",
    "    def on_test_begin(self, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"Start testing; got log keys: {}\".format(keys))\n",
    "\n",
    "    def on_test_end(self, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"Stop testing; got log keys: {}\".format(keys))\n",
    "\n",
    "    def on_predict_begin(self, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"Start predicting; got log keys: {}\".format(keys))\n",
    "\n",
    "    def on_predict_end(self, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"Stop predicting; got log keys: {}\".format(keys))\n",
    "\n",
    "    def on_train_batch_begin(self, batch, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"...Training: start of batch {}; got log keys: {}\".format(batch, keys))\n",
    "\n",
    "    def on_train_batch_end(self, batch, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"...Training: end of batch {}; got log keys: {}\".format(batch, keys))\n",
    "\n",
    "    def on_test_batch_begin(self, batch, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"...Evaluating: start of batch {}; got log keys: {}\".format(batch, keys))\n",
    "\n",
    "    def on_test_batch_end(self, batch, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"...Evaluating: end of batch {}; got log keys: {}\".format(batch, keys))\n",
    "\n",
    "    def on_predict_batch_begin(self, batch, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"...Predicting: start of batch {}; got log keys: {}\".format(batch, keys))\n",
    "\n",
    "    def on_predict_batch_end(self, batch, logs=None):\n",
    "        keys = list(logs.keys())\n",
    "        print(\"...Predicting: end of batch {}; got log keys: {}\".format(batch, keys))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "zaOKC7dpK7ng"
   },
   "source": [
    "Let's try it out:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "IiFWfxsNkypE"
   },
   "outputs": [],
   "source": [
    "model = get_model()\n",
    "model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    batch_size=128,\n",
    "    epochs=1,\n",
    "    verbose=0,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[CustomCallback()],\n",
    ")\n",
    "\n",
    "res = model.evaluate(\n",
    "    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]\n",
    ")\n",
    "\n",
    "res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "y9oQfhHERRsm"
   },
   "source": [
    "### Usage of `logs` dict\n",
    "The `logs` dict contains the loss value, and all the metrics at the end of a batch or\n",
    "epoch. Example includes the loss and mean absolute error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "dVErv1llBq0l"
   },
   "outputs": [],
   "source": [
    "class LossAndErrorPrintingCallback(keras.callbacks.Callback):\n",
    "    def on_train_batch_end(self, batch, logs=None):\n",
    "        print(\"For batch {}, loss is {:7.2f}.\".format(batch, logs[\"loss\"]))\n",
    "\n",
    "    def on_test_batch_end(self, batch, logs=None):\n",
    "        print(\"For batch {}, loss is {:7.2f}.\".format(batch, logs[\"loss\"]))\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        print(\n",
    "            \"The average loss for epoch {} is {:7.2f} \"\n",
    "            \"and mean absolute error is {:7.2f}.\".format(\n",
    "                epoch, logs[\"loss\"], logs[\"mean_absolute_error\"]\n",
    "            )\n",
    "        )\n",
    "\n",
    "\n",
    "model = get_model()\n",
    "model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    batch_size=128,\n",
    "    epochs=2,\n",
    "    verbose=0,\n",
    "    callbacks=[LossAndErrorPrintingCallback()],\n",
    ")\n",
    "\n",
    "res = model.evaluate(\n",
    "    x_test,\n",
    "    y_test,\n",
    "    batch_size=128,\n",
    "    verbose=0,\n",
    "    callbacks=[LossAndErrorPrintingCallback()],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AC03ymKbxZNq"
   },
   "source": [
    "## Usage of `self.model` attribute\n",
    "\n",
    "In addition to receiving log information when one of their methods is called,\n",
    "callbacks have access to the model associated with the current round of\n",
    "training/evaluation/inference: `self.model`.\n",
    "\n",
    "Here are of few of the things you can do with `self.model` in a callback:\n",
    "\n",
    "- Set `self.model.stop_training = True` to immediately interrupt training.\n",
    "- Mutate hyperparameters of the optimizer (available as `self.model.optimizer`),\n",
    "such as `self.model.optimizer.learning_rate`.\n",
    "- Save the model at period intervals.\n",
    "- Record the output of `model.predict()` on a few test samples at the end of each\n",
    "epoch, to use as a sanity check during training.\n",
    "- Extract visualizations of intermediate features at the end of each epoch, to monitor\n",
    "what the model is learning over time.\n",
    "- etc.\n",
    "\n",
    "Let's see this in action in a couple of examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "H2GuBD9Pl7Hg"
   },
   "source": [
    "## Examples of Keras callback applications"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kR5WkzikUzX8"
   },
   "source": [
    "### Early stopping at minimum loss\n",
    "\n",
    "This first example shows the creation of a `Callback` that stops training when the\n",
    "minimum of loss has been reached, by setting the attribute `self.model.stop_training`\n",
    "(boolean). Optionally, you can provide an argument `patience` to specify how many\n",
    "epochs we should wait before stopping after having reached a local minimum.\n",
    "\n",
    "`tf.keras.callbacks.EarlyStopping` provides a more complete and general implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "lbTowW8Tbk8Q"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "class EarlyStoppingAtMinLoss(keras.callbacks.Callback):\n",
    "    \"\"\"Stop training when the loss is at its min, i.e. the loss stops decreasing.\n",
    "\n",
    "  Arguments:\n",
    "      patience: Number of epochs to wait after min has been hit. After this\n",
    "      number of no improvement, training stops.\n",
    "  \"\"\"\n",
    "\n",
    "    def __init__(self, patience=0):\n",
    "        super(EarlyStoppingAtMinLoss, self).__init__()\n",
    "        self.patience = patience\n",
    "        # best_weights to store the weights at which the minimum loss occurs.\n",
    "        self.best_weights = None\n",
    "\n",
    "    def on_train_begin(self, logs=None):\n",
    "        # The number of epoch it has waited when loss is no longer minimum.\n",
    "        self.wait = 0\n",
    "        # The epoch the training stops at.\n",
    "        self.stopped_epoch = 0\n",
    "        # Initialize the best as infinity.\n",
    "        self.best = np.Inf\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        current = logs.get(\"loss\")\n",
    "        if np.less(current, self.best):\n",
    "            self.best = current\n",
    "            self.wait = 0\n",
    "            # Record the best weights if current results is better (less).\n",
    "            self.best_weights = self.model.get_weights()\n",
    "        else:\n",
    "            self.wait += 1\n",
    "            if self.wait >= self.patience:\n",
    "                self.stopped_epoch = epoch\n",
    "                self.model.stop_training = True\n",
    "                print(\"Restoring model weights from the end of the best epoch.\")\n",
    "                self.model.set_weights(self.best_weights)\n",
    "\n",
    "    def on_train_end(self, logs=None):\n",
    "        if self.stopped_epoch > 0:\n",
    "            print(\"Epoch %05d: early stopping\" % (self.stopped_epoch + 1))\n",
    "\n",
    "\n",
    "model = get_model()\n",
    "model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    batch_size=64,\n",
    "    steps_per_epoch=5,\n",
    "    epochs=30,\n",
    "    verbose=0,\n",
    "    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],\n",
    ")"
   ]
  },
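  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The patience counter used by `EarlyStoppingAtMinLoss` can also be traced without training\n",
    "a model. The sketch below replays the same comparison on a made-up list of per-epoch\n",
    "losses; `epochs_run` and `toy_losses` are hypothetical names and values chosen only for\n",
    "illustration. Under this logic, `patience=0` and `patience=1` stop at the same epoch,\n",
    "because the counter is already 1 the first time the loss fails to improve."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def epochs_run(losses, patience=0):\n",
    "    \"\"\"Return how many epochs run before the patience rule would stop training.\n",
    "\n",
    "    Hypothetical helper that mirrors the counter logic of `EarlyStoppingAtMinLoss`.\n",
    "    \"\"\"\n",
    "    best = float(\"inf\")\n",
    "    wait = 0\n",
    "    for epoch, current in enumerate(losses):\n",
    "        if current < best:\n",
    "            # New minimum: remember it and reset the counter.\n",
    "            best = current\n",
    "            wait = 0\n",
    "        else:\n",
    "            wait += 1\n",
    "            if wait >= patience:\n",
    "                # Training would stop at the end of this epoch.\n",
    "                return epoch + 1\n",
    "    return len(losses)\n",
    "\n",
    "\n",
    "# Made-up losses: a temporary bump at epoch 2, then a plateau, then a late improvement.\n",
    "toy_losses = [0.9, 0.7, 0.8, 0.6, 0.65, 0.66, 0.5]\n",
    "for patience in (0, 2, 3):\n",
    "    print(\"patience={}: ran {} epochs\".format(patience, epochs_run(toy_losses, patience)))"
   ]
  },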
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TU1xXEs9VES9"
   },
   "source": [
    "### Learning rate scheduling\n",
    "\n",
    "In this example, we show how a custom Callback can be used to dynamically change the\n",
    "learning rate of the optimizer during the course of training.\n",
    "\n",
    "See `callbacks.LearningRateScheduler` for a more general implementations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "aOhw3UB85Vao"
   },
   "outputs": [],
   "source": [
    "class CustomLearningRateScheduler(keras.callbacks.Callback):\n",
    "    \"\"\"Learning rate scheduler which sets the learning rate according to schedule.\n",
    "\n",
    "  Arguments:\n",
    "      schedule: a function that takes an epoch index\n",
    "          (integer, indexed from 0) and current learning rate\n",
    "          as inputs and returns a new learning rate as output (float).\n",
    "  \"\"\"\n",
    "\n",
    "    def __init__(self, schedule):\n",
    "        super(CustomLearningRateScheduler, self).__init__()\n",
    "        self.schedule = schedule\n",
    "\n",
    "    def on_epoch_begin(self, epoch, logs=None):\n",
    "        if not hasattr(self.model.optimizer, \"lr\"):\n",
    "            raise ValueError('Optimizer must have a \"lr\" attribute.')\n",
    "        # Get the current learning rate from model's optimizer.\n",
    "        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))\n",
    "        # Call schedule function to get the scheduled learning rate.\n",
    "        scheduled_lr = self.schedule(epoch, lr)\n",
    "        # Set the value back to the optimizer before this epoch starts\n",
    "        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)\n",
    "        print(\"\\nEpoch %05d: Learning rate is %6.4f.\" % (epoch, scheduled_lr))\n",
    "\n",
    "\n",
    "LR_SCHEDULE = [\n",
    "    # (epoch to start, learning rate) tuples\n",
    "    (3, 0.05),\n",
    "    (6, 0.01),\n",
    "    (9, 0.005),\n",
    "    (12, 0.001),\n",
    "]\n",
    "\n",
    "\n",
    "def lr_schedule(epoch, lr):\n",
    "    \"\"\"Helper function to retrieve the scheduled learning rate based on epoch.\"\"\"\n",
    "    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:\n",
    "        return lr\n",
    "    for i in range(len(LR_SCHEDULE)):\n",
    "        if epoch == LR_SCHEDULE[i][0]:\n",
    "            return LR_SCHEDULE[i][1]\n",
    "    return lr\n",
    "\n",
    "\n",
    "model = get_model()\n",
    "model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    batch_size=64,\n",
    "    steps_per_epoch=5,\n",
    "    epochs=15,\n",
    "    verbose=0,\n",
    "    callbacks=[\n",
    "        LossAndErrorPrintingCallback(),\n",
    "        CustomLearningRateScheduler(lr_schedule),\n",
    "    ],\n",
    ")"
   ]
  },
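  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To see which learning rate the schedule above yields at each epoch, it can be traced in\n",
    "plain Python, without a model. `TOY_SCHEDULE`, `toy_lr_schedule` and `trace` are\n",
    "hypothetical names for this sketch; the table repeats the values of `LR_SCHEDULE`, and the\n",
    "starting rate of 0.1 is the one set in `get_model()`. Values read back from a real\n",
    "optimizer may differ slightly because of floating-point precision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# (epoch to start, learning rate) tuples, repeated here so the cell runs on its own.\n",
    "TOY_SCHEDULE = [(3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)]\n",
    "\n",
    "\n",
    "def toy_lr_schedule(epoch, lr):\n",
    "    \"\"\"Return the scheduled rate at a milestone epoch, else keep the current rate.\"\"\"\n",
    "    for start_epoch, new_lr in TOY_SCHEDULE:\n",
    "        if epoch == start_epoch:\n",
    "            return new_lr\n",
    "    return lr\n",
    "\n",
    "\n",
    "lr = 0.1  # Initial learning rate, as in `get_model()`.\n",
    "trace = []\n",
    "for epoch in range(15):\n",
    "    # The rate returned for one epoch is fed back in as the current rate of the next.\n",
    "    lr = toy_lr_schedule(epoch, lr)\n",
    "    trace.append(lr)\n",
    "    print(\"Epoch {:2d}: learning rate {}\".format(epoch, lr))"
   ]
  },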
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Fwr7dO76pJr9"
   },
   "source": [
    "### Built-in Keras callbacks\n",
    "Be sure to check out the existing Keras callbacks by\n",
    "reading the [API docs](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/).\n",
    "Applications include logging to CSV, saving\n",
    "the model, visualizing metrics in TensorBoard, and a lot more!"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "custom_callback.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}